package org.xiaoweige.mybatis.interceptor;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.cache.impl.PerpetualCache;
import org.apache.ibatis.executor.BaseExecutor;
import org.apache.ibatis.executor.CachingExecutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.convert.ConversionService;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.xiaoweige.mybatis.annotation.CryptService;
import org.xiaoweige.mybatis.annotation.EnableCipher;
import org.xiaoweige.mybatis.annotation.Encrypted;
import org.xiaoweige.mybatis.config.BeanFactoryHolder;


/**
 * Field encryption/decryption interceptor.
 *
 * <p>Only mapper methods annotated with {@link EnableCipher} are affected:
 * <ul>
 *   <li>non-SELECT statements: the {@link Encrypted} fields of the parameter objects are rewritten
 *       in place (encrypted, or decrypted for {@code CipherType.DECRYPT}) before the statement runs.
 *       Values replaced by encryption are restored afterwards, even when the statement fails.</li>
 *   <li>SELECT statements: the {@link Encrypted} fields of the returned objects are decrypted in
 *       place, unless {@link #isCache} reports a hit in the executor's local cache.</li>
 * </ul>
 *
 * @author Jerry.hu
 * @summary Field encryption/decryption interceptor
 * @Copyright (c) 2018, xiaoweige Group All Rights Reserved.
 * @Description Field encryption/decryption interceptor
 * @since 2018-06-09 18:09
 */
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
public class FieldEncryptInterceptor implements Interceptor {

    private static final Logger LOGGER = LoggerFactory.getLogger(FieldEncryptInterceptor.class);
    private static final int EXECUTOR_PARAMETER_COUNT_4 = 4;
    private static final int MAPPED_STATEMENT_INDEX = 0;
    private static final int PARAMETER_INDEX = 1;
    private static final int ROW_BOUNDS_INDEX = 2;
    private static final int CACHE_KEY_INDEX = 4;
    private static final int BOUND_SQL_INDEX = 5;
    /** Values are sent to the CryptService in chunks of at most this many entries. */
    private static final int CIPHER_BATCH_SIZE = 1000;
    private static final String SELECT_KEY = "selectKey";
    /** Plaintext returned for the configured constant cipher text. */
    private static final String CONSTANT_VALUE = "0.00";
    /**
     * Plaintexts encrypted to the configured constant cipher text instead of calling the service.
     * NOTE(review): "0," looks like a typo for "0"; left as-is because "0" would round-trip to "0.00".
     */
    private static final List<String> CONSTANT_VALUES = Arrays.asList("0,", "0.0", "0.00", "0.000");

    /**
     * Per-thread record of the values overwritten by encryption, keyed by object IDENTITY so that
     * equal-but-distinct entities (or entities whose hashCode covers encrypted fields) are tracked
     * separately. Values are the original objects, so restoring preserves the field's type.
     */
    private final ThreadLocal<Map<Object, Map<Field, Object>>> originalFieldMapLocal =
            ThreadLocal.withInitial(IdentityHashMap::new);

    /** Id of the statement whose parameters are currently encrypted on this thread, or null. */
    private final ThreadLocal<String> sqlIdLocal = new ThreadLocal<>();

    /** {@link EnableCipher} lookup result (present or absent) per mapped statement id. */
    private final Map<String, Optional<EnableCipher>> enableCipherCache = new ConcurrentHashMap<>();

    /** Accessible, non-static {@link Encrypted} fields per class, including inherited ones. */
    private final Map<Class<?>, List<Field>> encryptedFieldCache = new ConcurrentHashMap<>();

    /** Value of the CONSTANT_CIPHER_TEXT property; the constant shortcut is off when empty. */
    private volatile String constantCipherText;

    /** Encryption/decryption service, looked up lazily from the bean factory. */
    private volatile CryptService cryptService;

    /** Type conversion service, looked up lazily from the bean factory. */
    private volatile ConversionService conversionService;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
        Object parameter = args[PARAMETER_INDEX];
        if (isSelectKeyOfCurrentUpdate(ms.getId())) {
            // A selectKey statement belonging to the update being encrypted on this thread:
            // put the caller's plaintext back before it runs (behaviour kept from the original).
            // NOTE(review): if the selectKey runs before the update's SQL, the SQL would then see
            // plaintext — confirm this is only used with post-insert key selection.
            restoreOriginalValues();
        }
        EnableCipher enableCipher = getEnableCipher(ms);
        if (enableCipher == null) {
            return invocation.proceed();
        }
        if (SqlCommandType.SELECT == ms.getSqlCommandType()) {
            boolean cached = isCache(invocation);
            Object result = invocation.proceed();
            return cached ? result : decrypt(result);
        }
        String enclosingUpdateId = sqlIdLocal.get();
        sqlIdLocal.set(ms.getId());
        try {
            updateParameters(enableCipher, parameter);
            return invocation.proceed();
        } finally {
            // Always give the caller its plaintext back and drop the thread-local bookkeeping,
            // even when encryption or the statement itself failed.
            restoreOriginalValues();
            if (enclosingUpdateId == null) {
                sqlIdLocal.remove();
            } else {
                sqlIdLocal.set(enclosingUpdateId);
            }
        }
    }

    /**
     * Tells whether the statement is a selectKey of the update currently encrypted on this thread.
     *
     * @param statementId id of the statement about to run
     * @return true when the id starts with the current update's id and ends with "selectKey"
     */
    private boolean isSelectKeyOfCurrentUpdate(String statementId) {
        String currentUpdateId = sqlIdLocal.get();
        return currentUpdateId != null
                && statementId.startsWith(currentUpdateId)
                && statementId.endsWith(SELECT_KEY);
    }

    /**
     * Encrypts (or, for {@code CipherType.DECRYPT}, decrypts) the parameter objects in place.
     *
     * @param enableCipher annotation of the mapper method
     * @param parameter    statement parameter: a single object, a collection, or a parameter map
     */
    private void updateParameters(EnableCipher enableCipher, Object parameter) {
        boolean decrypt = EnableCipher.CipherType.DECRYPT.equals(enableCipher.value());
        List<Object> targets = new ArrayList<>();
        if (parameter instanceof Map) {
            for (Object value : getDistinctParameterValues(parameter)) {
                addTargets(targets, value);
            }
        } else {
            addTargets(targets, parameter);
        }
        encryptByAnnByList(targets, decrypt);
    }

    /**
     * Adds a parameter value to the target list, expanding any collection into its elements.
     *
     * @param targets accumulator of objects to process
     * @param value   parameter value, possibly null
     */
    private static void addTargets(List<Object> targets, Object value) {
        if (value instanceof Collection) {
            targets.addAll((Collection<?>) value);
        } else if (value != null) {
            targets.add(value);
        }
    }

    /**
     * Returns the non-null values of a parameter map, each distinct instance only once.
     *
     * <p>A parameter map may expose the same object under several keys (for example a named key
     * and a positional alias); processing it twice would encrypt it twice. Identity is used
     * rather than hashCode so that distinct objects with colliding hashes are all processed.
     *
     * @param parameter the parameter map
     * @return distinct non-null values, in map iteration order
     */
    private List<Object> getDistinctParameterValues(Object parameter) {
        Set<Object> seen = Collections.newSetFromMap(new IdentityHashMap<>());
        List<Object> values = new ArrayList<>();
        for (Object value : ((Map<?, ?>) parameter).values()) {
            if (value != null && seen.add(value)) {
                values.add(value);
            }
        }
        return values;
    }

    /**
     * Writes every value recorded by encryption on this thread back into its object, then
     * discards the record. Safe to call when nothing was recorded.
     */
    private void restoreOriginalValues() {
        Map<Object, Map<Field, Object>> recorded = originalFieldMapLocal.get();
        try {
            recorded.forEach((target, originals) ->
                    originals.forEach((field, value) -> ReflectionUtils.setField(field, target, value)));
        } finally {
            originalFieldMapLocal.remove();
        }
    }

    /**
     * Checks whether the query result is already present in the executor's local cache.
     *
     * @param invocation the intercepted query invocation (4 or 6 arguments)
     * @return true if the local cache holds an entry for this query, false otherwise
     */
    private boolean isCache(Invocation invocation) {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
        Object parameter = args[PARAMETER_INDEX];
        BoundSql boundSql;
        CacheKey cacheKey = null;
        if (args.length == EXECUTOR_PARAMETER_COUNT_4) {
            boundSql = ms.getBoundSql(parameter);
        } else {
            cacheKey = (CacheKey) args[CACHE_KEY_INDEX];
            boundSql = (BoundSql) args[BOUND_SQL_INDEX];
        }
        Object target = invocation.getTarget();
        // The local cache is a field of BaseExecutor; a CachingExecutor wraps one as "delegate".
        // NOTE(review): any other target type (e.g. another plugin's proxy) fails the cast below.
        Executor baseExecutor = target instanceof CachingExecutor
                ? (Executor) readField(CachingExecutor.class, "delegate", target)
                : (BaseExecutor) target;
        if (cacheKey == null) {
            cacheKey = baseExecutor.createCacheKey(ms, parameter, (RowBounds) args[ROW_BOUNDS_INDEX], boundSql);
        }
        PerpetualCache localCache = (PerpetualCache) readField(BaseExecutor.class, "localCache", baseExecutor);
        return localCache.getObject(cacheKey) != null;
    }

    /**
     * Reads a (possibly private) field reflectively.
     *
     * @param owner     class declaring the field
     * @param fieldName name of the field
     * @param target    instance to read from
     * @return the field value
     * @throws IllegalStateException if the field does not exist on {@code owner}
     */
    private static Object readField(Class<?> owner, String fieldName, Object target) {
        Field field = ReflectionUtils.findField(owner, fieldName);
        if (field == null) {
            throw new IllegalStateException("Field '" + fieldName + "' not found on " + owner.getName());
        }
        ReflectionUtils.makeAccessible(field);
        return ReflectionUtils.getField(field, target);
    }

    /**
     * Decrypts the {@link Encrypted} fields of a query result in place.
     *
     * @param object query result: a collection of rows, a single row, or null
     * @return the same object, with its fields decrypted
     */
    private Object decrypt(Object object) {
        if (object instanceof Collection) {
            encryptByAnnByList((Collection<?>) object, true);
        } else if (object != null) {
            encryptByAnnByList(Collections.singletonList(object), true);
        }
        return object;
    }

    /**
     * Returns the {@link EnableCipher} annotation of the mapper method behind a statement.
     * The lookup result is cached per statement id.
     *
     * @param ms the mapped statement
     * @return the annotation, or null if the method is not annotated or cannot be resolved
     */
    private EnableCipher getEnableCipher(MappedStatement ms) {
        return enableCipherCache
                .computeIfAbsent(ms.getId(), id -> Optional.ofNullable(findEnableCipher(id)))
                .orElse(null);
    }

    /**
     * Resolves "fully.qualified.Mapper.method" to the first public method of that name carrying
     * {@link EnableCipher}.
     *
     * @param statementId mapped statement id
     * @return the annotation, or null when there is none or the class cannot be loaded
     */
    private static EnableCipher findEnableCipher(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String className = statementId.substring(0, lastDot);
        String methodName = statementId.substring(lastDot + 1);
        try {
            for (Method method : Class.forName(className).getMethods()) {
                if (method.getName().equals(methodName) && method.isAnnotationPresent(EnableCipher.class)) {
                    return method.getAnnotation(EnableCipher.class);
                }
            }
        } catch (ClassNotFoundException e) {
            LOGGER.debug("Mapper class {} not found for statement {}", className, statementId, e);
        }
        return null;
    }

    /**
     * Batch encrypts or decrypts the {@link Encrypted} fields of the given objects in place.
     *
     * <p>Null elements are skipped and each instance is processed once. Null field values are
     * left untouched. When encrypting, the replaced values are recorded so the caller can later
     * restore them.
     *
     * @param list    objects whose fields should be processed
     * @param decrypt true to decrypt, false to encrypt
     * @author Jerry.hu
     * @modifier Jerry.hu
     * @since 2018-06-28 19:40:23
     */
    private void encryptByAnnByList(Collection<?> list, boolean decrypt) {
        if (CollectionUtils.isEmpty(list)) {
            return;
        }
        // Processing one instance twice would feed its already converted value back into the
        // lookup map below and overwrite the field with null.
        Set<Object> seen = Collections.newSetFromMap(new IdentityHashMap<>());
        List<Object> targets = new ArrayList<>();
        for (Object candidate : list) {
            if (candidate != null && seen.add(candidate)) {
                targets.add(candidate);
            }
        }

        List<String> rawValues = targets.stream()
                .flatMap(target -> getEncryptedFields(target.getClass()).stream()
                        .map(field -> getField(field, target)))
                .filter(Objects::nonNull)
                .map(Object::toString)
                .distinct()
                .collect(Collectors.toList());
        if (rawValues.isEmpty()) {
            return;
        }

        Map<String, String> converted = IntStream.range(0, rawValues.size())
                .mapToObj(i -> new Tuple(i, rawValues.get(i)))
                .collect(Collectors.groupingBy(t -> t.index / CIPHER_BATCH_SIZE))
                .values().stream()
                .collect(HashMap::new,
                        (acc, batch) -> acc.putAll(batchCipher(batch, decrypt)),
                        HashMap::putAll
                );

        Map<Object, Map<Field, Object>> recorded = originalFieldMapLocal.get();
        for (Object target : targets) {
            Map<Field, Object> originals = null;
            for (Field field : getEncryptedFields(target.getClass())) {
                Object value = getField(field, target);
                if (value == null) {
                    // Nothing to convert, and nothing that would need restoring.
                    continue;
                }
                if (!decrypt) {
                    // Record before modifying so a failure below can still be rolled back.
                    if (originals == null) {
                        originals = recorded.computeIfAbsent(target, key -> new HashMap<>());
                    }
                    originals.putIfAbsent(field, value);
                }
                ReflectionUtils.setField(field, target,
                        getConversionService().convert(converted.get(value.toString()), field.getType()));
            }
        }
    }

    /** Position of a raw value in the batch list, used to split it into chunks. */
    static class Tuple {
        final int index;
        final String value;

        Tuple(int index, String value) {
            this.index = index;
            this.value = value;
        }
    }

    /**
     * Returns the {@link Encrypted}, non-static fields of a class and its superclasses.
     * The fields are made accessible, and the result is cached per class.
     *
     * @param type class to inspect
     * @return unmodifiable list of fields
     */
    private List<Field> getEncryptedFields(Class<?> type) {
        return encryptedFieldCache.computeIfAbsent(type, FieldEncryptInterceptor::findEncryptedFields);
    }

    private static List<Field> findEncryptedFields(Class<?> type) {
        List<Field> fields = new ArrayList<>();
        // Walk up to and including Object (whose superclass is null).
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                if (field.isAnnotationPresent(Encrypted.class) && !Modifier.isStatic(field.getModifiers())) {
                    ReflectionUtils.makeAccessible(field);
                    fields.add(field);
                }
            }
        }
        return Collections.unmodifiableList(fields);
    }

    /**
     * Reads a field value, making the field accessible first.
     *
     * @param field field to read
     * @param obj   instance to read from
     * @return the field value, possibly null
     * @author Jerry.hu
     * @modifier Jerry.hu
     * @since 2018-09-28 16:49:46
     */
    private Object getField(Field field, Object obj) {
        ReflectionUtils.makeAccessible(field);
        return ReflectionUtils.getField(field, obj);
    }

    /**
     * Encrypts or decrypts one chunk of raw values.
     *
     * <p>When CONSTANT_CIPHER_TEXT is configured, the zero values in {@link #CONSTANT_VALUES} map
     * to it and it maps back to {@link #CONSTANT_VALUE}, bypassing the service. When it is not
     * configured, every value goes through the service (previously zero values became null).
     *
     * @param oriValues raw values of this chunk
     * @param decrypt   true to decrypt, false to encrypt
     * @return mapping from each raw value to its converted value
     * @author Jerry.hu
     * @modifier Jerry.hu
     * @since 2018-06-29 13:28:27
     */
    private Map<String, String> batchCipher(List<Tuple> oriValues, boolean decrypt) {
        Map<String, String> result = new HashMap<>();
        String constantText = constantCipherText;
        boolean constantConfigured = StringUtils.hasLength(constantText);
        List<String> pending = new ArrayList<>();
        for (Tuple tuple : oriValues) {
            if (constantConfigured && decrypt && tuple.value.equals(constantText)) {
                result.put(tuple.value, CONSTANT_VALUE);
            } else if (constantConfigured && !decrypt && CONSTANT_VALUES.contains(tuple.value)) {
                result.put(tuple.value, constantText);
            } else {
                pending.add(tuple.value);
            }
        }
        if (pending.isEmpty()) {
            return result;
        }
        if (decrypt) {
            result.putAll(this.getCryptService().batchDecrypt(pending));
        } else {
            result.putAll(this.getCryptService().batchEncrypt(pending));
        }
        return result;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads CONSTANT_CIPHER_TEXT from the plugin properties and warns when it is missing.
     *
     * @param properties plugin properties, possibly null
     */
    @Override
    public void setProperties(Properties properties) {
        constantCipherText = properties == null ? null : properties.getProperty("CONSTANT_CIPHER_TEXT");
        if (!StringUtils.hasLength(constantCipherText)) {
            LOGGER.warn("=================================================================================================================================\n" +
                    "=====================================当前拦截器未配置常量密文值 key:CONSTANT_CIPHER_TEXT======================================\n" +
                    "=====================================为了避免出现异常请在mybatis-config.xml 中设置常量密文key 和值=====================================\n" +
                    "=================================================================================================================================");
        }
    }

    /**
     * Returns the encryption/decryption service, looking it up on first use.
     *
     * @return CryptService
     * @throws IllegalStateException if no CryptService bean is available
     */
    private CryptService getCryptService() {
        CryptService service = cryptService;
        if (service == null) {
            try {
                service = BeanFactoryHolder.getBean(CryptService.class);
            } catch (Exception e) {
                LOGGER.error("CryptService not found", e);
                throw new IllegalStateException("CryptService not found", e);
            }
            cryptService = service;
        }
        return service;
    }

    /**
     * Returns the conversion service used to turn converted strings into field types, looking it
     * up on first use.
     *
     * @return conversionService
     * @throws IllegalStateException if no type converter is available
     */
    private ConversionService getConversionService() {
        ConversionService service = conversionService;
        if (service == null) {
            try {
                service = BeanFactoryHolder.getTypeConverter();
            } catch (Exception e) {
                LOGGER.error("TypeConverter not found", e);
                throw new IllegalStateException("TypeConverter not found", e);
            }
            conversionService = service;
        }
        return service;
    }

}
